{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorFlow: saving a model with checkpoints during training"
   ]
  },
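  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Checkpointing lets you save the model while training is still in progress:\n",
    "* If the training program crashes, you do not have to start over; load a checkpoint and resume training from it\n",
    "* A prediction service can load a checkpoint to pick up an updated model"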
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Load the data and build the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "train_labels = train_labels[:10000]\n",
    "test_labels = test_labels[:10000]\n",
    "\n",
    "train_images = train_images[:10000].reshape(-1, 28 * 28) / 255.0\n",
    "test_images = test_images[:10000].reshape(-1, 28 * 28) / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 0, 4, 1, 9], dtype=uint8)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This is a 10-class classification task\n",
    "train_labels[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Define a simple model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a simple sequential model\n",
    "def create_model():\n",
    "    model = tf.keras.models.Sequential([\n",
    "        keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n",
    "        keras.layers.Dense(10, activation='softmax')\n",
    "    ])\n",
    "\n",
    "    model.compile(optimizer='rmsprop',\n",
    "                  loss='sparse_categorical_crossentropy',\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Save the model during training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tf.keras.callbacks.ModelCheckpoint` is a callback that saves the model during training and at the end of training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a basic model instance\n",
    "model = create_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a callback that saves only the model's weights at the end of each epoch\n",
    "checkpoint_path = \"./traing_ckpt/cp_{epoch:02d}.ckpt\"\n",
    "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n",
    "                                                 save_weights_only=True,\n",
    "                                                 verbose=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "312/313 [============================>.] - ETA: 0s - loss: 0.4026 - accuracy: 0.8807\n",
      "Epoch 00001: saving model to ./traing_ckpt/cp_01.ckpt\n",
      "313/313 [==============================] - 3s 10ms/step - loss: 0.4024 - accuracy: 0.8808 - val_loss: 0.2605 - val_accuracy: 0.9197\n",
      "Epoch 2/10\n",
      "312/313 [============================>.] - ETA: 0s - loss: 0.1787 - accuracy: 0.9482\n",
      "Epoch 00002: saving model to ./traing_ckpt/cp_02.ckpt\n",
      "313/313 [==============================] - 3s 9ms/step - loss: 0.1785 - accuracy: 0.9483 - val_loss: 0.2023 - val_accuracy: 0.9397\n",
      "Epoch 3/10\n",
      "311/313 [============================>.] - ETA: 0s - loss: 0.1183 - accuracy: 0.9651\n",
      "Epoch 00003: saving model to ./traing_ckpt/cp_03.ckpt\n",
      "313/313 [==============================] - 3s 11ms/step - loss: 0.1181 - accuracy: 0.9651 - val_loss: 0.2258 - val_accuracy: 0.9346\n",
      "Epoch 4/10\n",
      "308/313 [============================>.] - ETA: 0s - loss: 0.0819 - accuracy: 0.9754\n",
      "Epoch 00004: saving model to ./traing_ckpt/cp_04.ckpt\n",
      "313/313 [==============================] - 4s 13ms/step - loss: 0.0819 - accuracy: 0.9754 - val_loss: 0.1833 - val_accuracy: 0.9460\n",
      "Epoch 5/10\n",
      "313/313 [==============================] - ETA: 0s - loss: 0.0569 - accuracy: 0.9837\n",
      "Epoch 00005: saving model to ./traing_ckpt/cp_05.ckpt\n",
      "313/313 [==============================] - 4s 14ms/step - loss: 0.0569 - accuracy: 0.9837 - val_loss: 0.1748 - val_accuracy: 0.9525\n",
      "Epoch 6/10\n",
      "311/313 [============================>.] - ETA: 0s - loss: 0.0392 - accuracy: 0.9886\n",
      "Epoch 00006: saving model to ./traing_ckpt/cp_06.ckpt\n",
      "313/313 [==============================] - 3s 8ms/step - loss: 0.0391 - accuracy: 0.9887 - val_loss: 0.1720 - val_accuracy: 0.9511\n",
      "Epoch 7/10\n",
      "313/313 [==============================] - ETA: 0s - loss: 0.0260 - accuracy: 0.9923\n",
      "Epoch 00007: saving model to ./traing_ckpt/cp_07.ckpt\n",
      "313/313 [==============================] - 3s 10ms/step - loss: 0.0260 - accuracy: 0.9923 - val_loss: 0.1986 - val_accuracy: 0.9511\n",
      "Epoch 8/10\n",
      "308/313 [============================>.] - ETA: 0s - loss: 0.0198 - accuracy: 0.9934\n",
      "Epoch 00008: saving model to ./traing_ckpt/cp_08.ckpt\n",
      "313/313 [==============================] - 3s 10ms/step - loss: 0.0198 - accuracy: 0.9934 - val_loss: 0.1778 - val_accuracy: 0.9569\n",
      "Epoch 9/10\n",
      "311/313 [============================>.] - ETA: 0s - loss: 0.0112 - accuracy: 0.9970\n",
      "Epoch 00009: saving model to ./traing_ckpt/cp_09.ckpt\n",
      "313/313 [==============================] - 3s 9ms/step - loss: 0.0112 - accuracy: 0.9970 - val_loss: 0.1720 - val_accuracy: 0.9612\n",
      "Epoch 10/10\n",
      "313/313 [==============================] - ETA: 0s - loss: 0.0076 - accuracy: 0.9981\n",
      "Epoch 00010: saving model to ./traing_ckpt/cp_10.ckpt\n",
      "313/313 [==============================] - 3s 8ms/step - loss: 0.0076 - accuracy: 0.9981 - val_loss: 0.1980 - val_accuracy: 0.9564\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fd8c5fe8490>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the model; the callback saves a checkpoint at the end of every epoch\n",
    "model.fit(train_images,\n",
    "          train_labels,\n",
    "          epochs=10,\n",
    "          validation_data=(test_images, test_labels),\n",
    "          callbacks=[cp_callback])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "313/313 - 1s - loss: 0.1980 - accuracy: 0.9564\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.1979878842830658, 0.9563999772071838]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_images, test_labels, verbose=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Use the checkpoint files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new, untrained model\n",
    "new_model = create_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the weights and run prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd8c5ed47d0>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the weights; pass the checkpoint prefix without suffixes such as .index\n",
    "new_model.load_weights(\"./traing_ckpt/cp_09.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "313/313 - 0s - loss: 0.1720 - accuracy: 0.9612\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.17197643220424652, 0.9611999988555908]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate right away, without any further training\n",
    "new_model.evaluate(test_images, test_labels, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.8646621e-11, 4.0034703e-15, 6.1754810e-09, 1.1134078e-05,\n",
       "        1.9557399e-16, 2.0905860e-11, 7.6235390e-19, 9.9998891e-01,\n",
       "        6.6065271e-11, 2.8822367e-09],\n",
       "       [1.8763623e-12, 2.7294575e-11, 1.0000000e+00, 5.7642353e-08,\n",
       "        1.0360668e-21, 3.5209656e-12, 2.5003152e-10, 1.4259344e-16,\n",
       "        2.9851371e-10, 8.9087966e-21],\n",
       "       [4.0501280e-10, 9.9966383e-01, 1.7066914e-04, 2.2081480e-05,\n",
       "        3.0094197e-05, 2.1883482e-08, 1.5438196e-06, 2.7922259e-05,\n",
       "        8.3738938e-05, 8.2594489e-08]], dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predict right away with the loaded weights\n",
    "new_model.predict(test_images[:3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Resume training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "313/313 [==============================] - 3s 10ms/step - loss: 0.0085 - accuracy: 0.9977 - val_loss: 0.1817 - val_accuracy: 0.9599\n",
      "Epoch 2/3\n",
      "313/313 [==============================] - 4s 12ms/step - loss: 0.0064 - accuracy: 0.9985 - val_loss: 0.1862 - val_accuracy: 0.9617\n",
      "Epoch 3/3\n",
      "313/313 [==============================] - 4s 12ms/step - loss: 0.0033 - accuracy: 0.9992 - val_loss: 0.1963 - val_accuracy: 0.9600\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fd8bc0a1490>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
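    }
   ],
   "source": [
    "# Continue training from the loaded weights (no checkpoint callback is passed here);\n",
    "# the loss picks up close to where the checkpoint left off\n",
    "new_model.fit(train_images,\n",
    "              train_labels,\n",
    "              epochs=3,\n",
    "              validation_data=(test_images, test_labels))"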
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
